from abc import ABC, abstractmethod

import torch.nn as nn
import numpy as np

from knowledge_tracing.args import ARGS
import knowledge_tracing.network.transformer_layer as layers
from knowledge_tracing.network.util_network import shift_right


class Feature(ABC):
    """Common interface implemented by every model input feature.

    A concrete feature supplies four hooks: turning a list of interactions
    into a numpy array, encoder-side and decoder-side preprocessing of a
    batched tensor, and construction of the layer that embeds the feature.
    Instantiating a subclass that leaves any hook unimplemented raises
    ``TypeError``.
    """

    @abstractmethod
    def tensorize(self, interactions, seq_size):
        """Convert one sequence of interactions into a numpy array (target length ``seq_size``)."""

    @abstractmethod
    def preprocess_enc(self, input_tensor):
        """Encoder-side preprocessing of a batched input tensor."""

    @abstractmethod
    def preprocess_dec(self, input_tensor):
        """Decoder-side preprocessing of a batched input tensor."""

    @abstractmethod
    def embed_layer(self, embed_size):
        """Build the layer that maps this feature to ``embed_size`` dimensions."""


# num_categories includes the start token but does not include the padding index.
class CategoricalFeature(Feature):
    """Integer-coded feature embedded through a lookup table.

    Args:
        num_categories: number of category ids, counting the start token but
            not the padding id; passed to ``EmbeddingLayer`` as ``input_size``.
        name: attribute read from each interaction in ``tensorize``.
        padding: id written to positions past the end of the sequence; also
            passed to ``EmbeddingLayer`` as ``padding_idx``.
        start_token: fill value handed to ``shift_right`` in
            ``preprocess_dec``, or ``None`` to leave the decoder input as-is.
    """

    def __init__(self, num_categories, name, padding, start_token):
        self.num_categories = num_categories
        self.padding = padding
        self.name = name
        self.start_token = start_token

    # interactions: list of Interactions
    # generic tensorize().
    def tensorize(self, interactions, seq_size):
        """Read ``self.name`` from each interaction and right-pad with ``self.padding``.

        Returns an int64 array of length ``max(seq_size, len(interactions))``.
        Inputs longer than ``seq_size`` are not truncated.
        """
        values = [getattr(inter, self.name) for inter in interactions]
        values += [self.padding] * (seq_size - len(values))
        # np.long was removed in NumPy 1.24 (and in NumPy 2 it is the C long,
        # 32-bit on Windows); request a 64-bit integer explicitly.
        return np.array(values, dtype=np.int64)

    # input tensor: shape [batch_size, sequence_size]
    def preprocess_enc(self, input_tensor):
        """Return the encoder input unchanged."""
        return input_tensor

    def preprocess_dec(self, input_tensor):
        """Shift the decoder input by 1 along dim 1, filling with ``start_token``.

        When ``start_token`` is ``None`` the tensor is returned unchanged.
        """
        if self.start_token is not None:
            return shift_right(input_tensor, 1, pad_value=self.start_token, dim=1)
        return input_tensor

    def embed_layer(self, embed_size):
        """Build an ``EmbeddingLayer`` sized for this feature's categories."""
        return layers.EmbeddingLayer(input_size=self.num_categories,
                                     embed_size=embed_size,
                                     padding_idx=self.padding)


class IsCorrect(CategoricalFeature):
    """Response correctness encoded as a categorical id.

    Ids: 0 = padding (also used when ``is_correct`` is ``None``),
    1 = correct, 2 = wrong, 3 = decoder start token.
    """

    def __init__(self, name):
        super().__init__(num_categories=3,  # correct -> 1, wrong -> 2, start_token -> 3
                         name=name, padding=0, start_token=3)

    def tensorize(self, inters, seq_size):
        """Encode ``is_correct`` of each interaction and right-pad to ``seq_size``.

        Reads the ``is_correct`` attribute directly, not ``self.name``.
        Returns an int64 array of length ``max(seq_size, len(inters))``.
        """
        def _inter_to_index(i):
            # True/1 -> 1 (correct), False/0 -> 2 (wrong), None -> 0 (padding id).
            return 2 - i.is_correct if i.is_correct is not None else 0

        values = [_inter_to_index(i) for i in inters]
        values += [self.padding] * (seq_size - len(values))
        # np.long was removed in NumPy 1.24; use an explicit 64-bit integer dtype.
        return np.array(values, dtype=np.int64)


class NumericalFeature(Feature):
    """Real-valued feature projected into the embedding space by ``nn.Linear``.

    Args:
        num_dims: number of values per interaction (input width of the linear layer).
        name: attribute read from each interaction in ``tensorize``.
        padding: value written to every dimension of padded positions.
        start_token: fill value handed to ``shift_right`` in
            ``preprocess_dec``, or ``None`` to leave the decoder input as-is.
    """

    def __init__(self, num_dims, name, padding, start_token):
        self.num_dims = num_dims
        self.padding = padding
        self.name = name
        self.start_token = start_token

    # generic tensorize().
    def tensorize(self, interactions, seq_size):
        """Return a float32 array of shape ``(max(seq_size, len(interactions)), num_dims)``.

        With ``num_dims == 1`` the attribute is treated as a scalar. Otherwise it
        must be a sequence of exactly ``num_dims`` numbers.

        Raises:
            ValueError: if a multi-dimensional value does not have ``num_dims`` entries.
        """
        rows = []
        for inter in interactions:
            value = getattr(inter, self.name)
            if self.num_dims == 1:
                row = [value]
            else:
                row = list(value)
                if len(row) != self.num_dims:
                    raise ValueError(
                        f'{self.name}: expected {self.num_dims} values, got {len(row)}')
            rows.append(row)
        rows.extend([self.padding] * self.num_dims for _ in range(seq_size - len(rows)))
        # Explicit reshape keeps the result 2-D even when there are no rows.
        return np.array(rows, dtype=np.float32).reshape(len(rows), self.num_dims)

    def preprocess_enc(self, input_tensor):
        """Return the encoder input unchanged."""
        return input_tensor

    def preprocess_dec(self, input_tensor):
        """Shift the decoder input by 1 along dim 1, filling with ``start_token``.

        When ``start_token`` is ``None`` the tensor is returned unchanged.
        """
        if self.start_token is not None:
            return shift_right(input_tensor, 1, pad_value=self.start_token, dim=1)
        return input_tensor

    def embed_layer(self, embed_size):
        """Build a linear projection from ``num_dims`` to ``embed_size``."""
        return nn.Linear(self.num_dims, embed_size)


class LagTime(NumericalFeature):
    """Normalized gap between each interaction's start time and the previous one.

    Every gap ``start_time[k] - start_time[k - 1]`` is capped at
    ``ARGS.max_lag_time`` and divided by it. Position 0 has no predecessor and
    holds -1.0, the same value this feature uses as its start token.
    """

    def __init__(self, name):
        super().__init__(1, name, padding=0.0, start_token=-1.0)

    def tensorize(self, inters, seq_size):
        """Return a float32 array of shape ``(L, 1)``, ``L = max(seq_size, len(inters), 1)``."""
        starts = [inter.start_time for inter in inters]
        lags = [-1.0]
        for earlier, later in zip(starts, starts[1:]):
            lags.append(min(later - earlier, ARGS.max_lag_time) / ARGS.max_lag_time)
        lags.extend([self.padding] * (seq_size - len(lags)))
        return np.array(lags, dtype=np.float32)[:, np.newaxis]


class ElapsedTime(NumericalFeature):
    """Per-interaction ``elapsed_time`` capped at ``ARGS.max_elapsed_time`` and divided by it."""

    def __init__(self, name):
        super().__init__(1, name, padding=0.0, start_token=-1.0)

    def tensorize(self, inters, seq_size):
        """Return a float32 array of shape ``(max(seq_size, len(inters)), 1)``."""
        scaled = []
        for inter in inters:
            capped = min(inter.elapsed_time, ARGS.max_elapsed_time)
            scaled.append(capped / ARGS.max_elapsed_time)
        scaled.extend(self.padding for _ in range(seq_size - len(scaled)))
        return np.array(scaled, dtype=np.float32)[:, np.newaxis]


class PositionalFeature(Feature):
    """1-based position index of each slot in a sequence of length ``seq_len``."""

    def __init__(self, seq_len, name):
        self.seq_len = seq_len
        self.name = name

    def tensorize(self, inters, seq_size):
        """Return positions ``1, 2, ..., seq_len`` as an int64 array.

        ``inters`` and ``seq_size`` are not used. Padded slots also receive
        non-zero positions.
        """
        # Explicit int64: np.arange's default integer is platform-dependent
        # (int32 on Windows with NumPy < 2). This matches the other index features.
        return np.arange(1, self.seq_len + 1, dtype=np.int64)

    def preprocess_enc(self, input_tensor):
        """Return the encoder input unchanged."""
        return input_tensor

    def preprocess_dec(self, input_tensor):
        """Return the decoder input unchanged (no start-token shift for positions)."""
        return input_tensor

    def embed_layer(self, embed_size):
        """Build an ``EmbeddingLayer`` for ``seq_len`` positions with padding id 0."""
        return layers.EmbeddingLayer(input_size=self.seq_len,
                                     embed_size=embed_size,
                                     padding_idx=0)


class CategoricalMultiFeature(Feature):
    """Feature where each interaction carries a variable-length list of category ids.

    Args:
        num_categories: passed to ``MultiEmbeddingLayer`` as ``input_size``.
        max_num_features: width every per-interaction id list is padded to.
        name: attribute holding the id list on each interaction.
        padding: id used for list padding and sequence padding; also the
            embedding ``padding_idx``.
        start_token: fill value handed to ``shift_right`` in
            ``preprocess_dec``, or ``None`` to leave the decoder input as-is.
        collate: how ``MultiEmbeddingLayer`` combines the per-id embeddings;
            documented options are 'max', 'sum' and 'average'.
    """

    def __init__(self,
                 num_categories,
                 max_num_features,
                 name,
                 padding,
                 start_token,
                 collate='average'):
        # Supported collate is ['max', 'sum', 'average'].
        self.num_categories = num_categories
        self.max_num_features = max_num_features
        self.name = name
        self.padding = padding
        self.start_token = start_token
        self.collate = collate

    # generic tensorize().
    def tensorize(self, interactions, seq_size):
        """Pad each id list to ``max_num_features`` and the sequence to ``seq_size``.

        Returns an int64 array of shape
        ``(max(seq_size, len(interactions)), max_num_features)``. The
        interactions' own lists are left unmodified.

        Raises:
            ValueError: if an interaction has more than ``max_num_features`` ids.
        """
        width = self.max_num_features
        output = []
        for inter in interactions:
            # Copy before padding: the previous in-place `+=` permanently
            # altered the list stored on the interaction.
            inter_features = list(getattr(inter, self.name))
            # Raise instead of assert so the check survives `python -O`.
            if len(inter_features) > width:
                raise ValueError('Number of features more than max_num_features')
            output.append(inter_features + [self.padding] * (width - len(inter_features)))
        output.extend([self.padding] * width for _ in range(seq_size - len(output)))
        # Explicit reshape keeps the result 2-D even when there are no rows.
        return np.array(output, dtype=np.int64).reshape(len(output), width)

    # input tensor: shape [batch_size, sequence_size]
    def preprocess_enc(self, input_tensor):
        """Return the encoder input unchanged."""
        return input_tensor

    def preprocess_dec(self, input_tensor):
        """Shift the decoder input by 1 along dim 1, filling with ``start_token``.

        When ``start_token`` is ``None`` the tensor is returned unchanged.
        """
        if self.start_token is not None:
            return shift_right(input_tensor,
                               1,
                               pad_value=self.start_token,
                               dim=1)
        return input_tensor

    def embed_layer(self, embed_size):
        """Build a ``MultiEmbeddingLayer`` that combines ids using ``self.collate``."""
        return layers.MultiEmbeddingLayer(input_size=self.num_categories,
                                          embed_size=embed_size,
                                          padding_idx=self.padding,
                                          collate=self.collate)


class SequenceSize(Feature):
    """Number of interactions in a sequence, before padding.

    Not a model input: the preprocessing and embedding hooks raise
    ``AssertionError`` if they are ever called.
    """

    def __init__(self, name):
        self.name = name

    def tensorize(self, inters, seq_size):
        """Return a one-element int64 array holding ``len(inters)``; ``seq_size`` is unused."""
        # np.long was removed in NumPy 1.24; use an explicit 64-bit integer dtype.
        return np.array([len(inters)], dtype=np.int64)

    # The hooks below use an explicit raise instead of `assert False`: asserts
    # are stripped under `python -O`, which would make them silently return None.
    def preprocess_enc(self, input_tensor):
        raise AssertionError('SequenceSize has no encoder preprocessing')

    def preprocess_dec(self, input_tensor):
        raise AssertionError('SequenceSize has no decoder preprocessing')

    def embed_layer(self, embed_size):
        raise AssertionError('SequenceSize has no embedding layer')


class PaddingMask(Feature):
    """Boolean mask: True at real interaction positions, False at padding.

    Not a model input: the preprocessing and embedding hooks raise
    ``AssertionError`` if they are ever called.
    """

    def __init__(self, name):
        self.name = name

    def tensorize(self, inters, seq_size):
        """Return a bool array of length ``max(seq_size, len(inters))``."""
        value = [True] * len(inters) + [False] * (seq_size - len(inters))
        # np.bool was removed in NumPy 1.24; np.bool_ works on every version.
        return np.array(value, dtype=np.bool_)

    # The hooks below use an explicit raise instead of `assert False`: asserts
    # are stripped under `python -O`, which would make them silently return None.
    def preprocess_enc(self, input_tensor):
        raise AssertionError('PaddingMask has no encoder preprocessing')

    def preprocess_dec(self, input_tensor):
        raise AssertionError('PaddingMask has no decoder preprocessing')

    def embed_layer(self, embed_size):
        raise AssertionError('PaddingMask has no embedding layer')


class LossMask(Feature):
    """Boolean mask of positions that contribute to the loss.

    Every real interaction is marked True and padding is marked False.
    Not a model input: the preprocessing and embedding hooks raise
    ``AssertionError`` if they are ever called.
    """

    def __init__(self, name):
        self.name = name

    def tensorize(self, inters, seq_size):
        """Return a bool array of length ``max(seq_size, len(inters))``."""
        # NOTE: all interactions are treated as questions. A questions-only mask
        # would instead use `i.item_id[0] == 'q'` for each interaction.
        value = [True] * len(inters) + [False] * (seq_size - len(inters))
        # np.bool was removed in NumPy 1.24; np.bool_ works on every version.
        return np.array(value, dtype=np.bool_)

    # The hooks below use an explicit raise instead of `assert False`: asserts
    # are stripped under `python -O`, which would make them silently return None.
    def preprocess_enc(self, input_tensor):
        raise AssertionError('LossMask has no encoder preprocessing')

    def preprocess_dec(self, input_tensor):
        raise AssertionError('LossMask has no decoder preprocessing')

    def embed_layer(self, embed_size):
        raise AssertionError('LossMask has no embedding layer')
